package com.icesoft.core.web.suppose.safehttp.request;

import com.fasterxml.jackson.core.type.TypeReference;
import com.icesoft.core.common.util.JsonUtil;
import com.icesoft.core.web.helper.RequestHold;
import com.icesoft.core.web.suppose.safehttp.filter.SafeRequestCryptoService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Request wrapper that caches the request body so it can be read more than once and, for
 * non-GET requests, replaces it with the body decrypted by a {@link SafeRequestCryptoService}.
 *
 * <p>Once {@link #resolveParam()} has run, the {@code getParameter*} methods are answered from
 * the resolved parameter map instead of the wrapped request. Once the body has been read
 * ({@link #getBody()} or {@link #resolveParam()}), {@link #getInputStream()} and
 * {@link #getReader()} replay the cached (possibly decrypted) body.
 *
 * <p>The constructor registers the wrapper as the request attribute {@link #ATTR_NAME}.
 * No synchronization is performed.
 */
@Slf4j
public class SafeHttpServletRequest extends HttpServletRequestWrapper {
    /** Request attribute under which the constructor stores this wrapper. */
    public static final String ATTR_NAME = "SafeHttpServletRequest";

    /** The wrapped request; read directly so the cached stream is bypassed when loading the body. */
    private final HttpServletRequest request;
    /** Service used to decrypt non-GET request bodies. */
    private final SafeRequestCryptoService safeRequestCryptoService;
    /** True when the wrapped request's method is exactly {@code "GET"}. */
    private final boolean isGetMethod;

    /** Resolved parameters; {@code null} until {@link #resolveParam()} has run. */
    private TreeMap<String, String> paramMap;
    /** Replayable body stream; {@code null} until the body has been read. */
    private ServletInputStream inputStream;
    /** Reader over {@link #inputStream}; reset whenever that stream is replaced. */
    private BufferedReader reader;
    /** Raw body decoded as UTF-8; {@code null} until {@link #getBody()} has run. */
    private String body;

    public SafeHttpServletRequest(SafeRequestCryptoService safeRequestCryptoService, HttpServletRequest request) {
        super(request);
        this.request = request;
        this.safeRequestCryptoService = safeRequestCryptoService;
        this.isGetMethod = "GET".equals(request.getMethod());
        request.setAttribute(ATTR_NAME, this);
    }

    /**
     * Returns the raw request body decoded as UTF-8, reading it from the wrapped request on the
     * first call and caching it. After this call the same bytes are replayed by
     * {@link #getInputStream()} / {@link #getReader()}.
     *
     * @return the raw (not decrypted) body text
     * @throws IOException if reading the wrapped request's input stream fails
     */
    public String getBody() throws IOException {
        if (body != null) {
            return body;
        }
        // Read the wrapped request directly: this.getInputStream() would return the cache.
        byte[] data = IOUtils.toByteArray(request.getInputStream());
        replaceCachedBody(data);
        body = new String(data, StandardCharsets.UTF_8);
        return body;
    }

    /**
     * Resolves the request parameters once and caches them.
     *
     * <p>For GET requests, and for other requests with a blank body, the parameters come from
     * {@code RequestHold.getRequestParamMap(request)}. Otherwise the body is passed through
     * {@code safeRequestCryptoService.decryptReqBody}. The result replaces the cached body
     * stream, so downstream readers see the decrypted text. That text is then parsed as a JSON
     * object into a sorted {@code String -> String} map.
     *
     * @return the resolved parameters, never {@code null}
     * @throws IOException if reading the request body fails
     */
    public TreeMap<String, String> resolveParam()
            throws IOException {
        if (paramMap != null) {
            return paramMap;
        }
        if (isGetMethod) {
            paramMap = new TreeMap<>(RequestHold.getRequestParamMap(request));
            return paramMap;
        }
        String paramStr = getBody();
        if (StringUtils.isBlank(paramStr)) {
            paramMap = new TreeMap<>(RequestHold.getRequestParamMap(request));
            return paramMap;
        }
        // NOTE(review): the raw and decrypted bodies are logged at TRACE; make sure TRACE is
        // never enabled where request payloads are sensitive.
        log.trace("body流数据：{}", paramStr);
        if (log.isTraceEnabled()) {
            log.trace("解密类: {}", safeRequestCryptoService.getClass());
        }
        paramStr = safeRequestCryptoService.decryptReqBody(paramStr);
        // Downstream consumers of getInputStream()/getReader() must see the decrypted body.
        replaceCachedBody(paramStr.getBytes(StandardCharsets.UTF_8));
        log.trace("请求json：{}", paramStr);
        if (StringUtils.isNotBlank(paramStr)) {
            paramMap = JsonUtil.toObject(paramStr, new TypeReference<TreeMap<String, String>>() {
            });
        }
        // If nothing was parsed (blank decrypted text, or JsonUtil returned null - presumably for
        // a literal JSON "null"), cache an empty map. Otherwise every later call would re-decrypt.
        if (paramMap == null) {
            paramMap = new TreeMap<>();
        }
        return paramMap;
    }

    /**
     * Installs {@code data} as the stream served by {@link #getInputStream()} and
     * {@link #getReader()}, discarding any reader built over a previous stream.
     */
    private void replaceCachedBody(byte[] data) {
        inputStream = new RequestCachingInputStream(data);
        reader = null;
    }

    /**
     * Returns a reader over the cached body when one exists, otherwise the wrapped request's
     * reader.
     *
     * <p>The reader uses the request's character encoding. When none is declared it falls back
     * to UTF-8, the charset this class uses when decoding and re-encoding the body.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        if (inputStream == null) {
            return super.getReader();
        }
        if (reader == null) {
            String encoding = getCharacterEncoding();
            // InputStreamReader(InputStream, String) throws NPE for a null charset name.
            reader = new BufferedReader(encoding != null
                    ? new InputStreamReader(inputStream, encoding)
                    : new InputStreamReader(inputStream, StandardCharsets.UTF_8));
        }
        return reader;
    }

    /** Returns the cached body stream when one exists, otherwise the wrapped request's stream. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return inputStream != null ? inputStream : super.getInputStream();
    }

    /** Returns the resolved value for {@code name} once parameters are resolved; else delegates. */
    @Override
    public String getParameter(String name) {
        if (paramMap != null) {
            return paramMap.get(name);
        }
        return super.getParameter(name);
    }

    /**
     * Returns a single-element array with the resolved value for {@code name}, or {@code null}
     * if the parameter is absent (as the servlet contract requires). Delegates when parameters
     * are not yet resolved.
     */
    @Override
    public String[] getParameterValues(String name) {
        if (paramMap != null) {
            if (!paramMap.containsKey(name)) {
                return null;
            }
            return new String[]{paramMap.get(name)};
        }
        return super.getParameterValues(name);
    }

    /**
     * Returns an unmodifiable, key-sorted view of the resolved parameters, each value wrapped in
     * a single-element array. Delegates when parameters are not yet resolved.
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        if (paramMap != null) {
            // LinkedHashMap keeps the TreeMap's sorted key order.
            Map<String, String[]> map = new LinkedHashMap<>();
            for (Map.Entry<String, String> entry : paramMap.entrySet()) {
                map.put(entry.getKey(), new String[]{entry.getValue()});
            }
            return Collections.unmodifiableMap(map);
        }
        return super.getParameterMap();
    }

    /** Returns the resolved parameter names in sorted order; delegates when not yet resolved. */
    @Override
    public Enumeration<String> getParameterNames() {
        if (paramMap != null) {
            return Collections.enumeration(paramMap.keySet());
        }
        return super.getParameterNames();
    }

    /** {@link ServletInputStream} that replays an in-memory byte array. */
    private static final class RequestCachingInputStream extends ServletInputStream {

        private final ByteArrayInputStream inputStream;

        RequestCachingInputStream(byte[] bytes) {
            inputStream = new ByteArrayInputStream(bytes);
        }

        @Override
        public int read() throws IOException {
            return inputStream.read();
        }

        /** Bulk read delegated to the byte array stream (the default implementation reads byte by byte). */
        @Override
        public int read(byte[] b, int off, int len) {
            return inputStream.read(b, off, len);
        }

        @Override
        public int available() {
            return inputStream.available();
        }

        @Override
        public boolean isFinished() {
            return inputStream.available() == 0;
        }

        /** All data is already in memory, so the stream is always ready. */
        @Override
        public boolean isReady() {
            return true;
        }

        /** Non-blocking reads are not supported; the listener is ignored. */
        @Override
        public void setReadListener(ReadListener readlistener) {
        }

    }

}
